%reload_ext autoreload
%autoreload 2
%matplotlib inline
from fastai.vision import *
from fastai.callbacks.hooks import *
from fastai.utils.mem import *
from pathlib import Path
from functools import partial
import cv2 as cv
import matplotlib.pyplot as plt
import numpy as np
import re
import random
#/hpf/largeprojects/MICe/mdagys/Cnp-GFP_Study/2019-06-10_labelled/raw
# Cut every raw image/label pair into non-overlapping l x l tiles and write
# them to processed_dir. Tiles whose label is all zero are "empty" and are
# discarded with probability `cutoff`.
raw_dir = Path("raw")
raws = raw_dir.ls()
# Images and labels are paired purely by sorted order, so both lists must line up.
images = sorted([raw_path for raw_path in raws if "_image" in raw_path.name])
labels = sorted([raw_path for raw_path in raws if "_label" in raw_path.name])
# The D-R_Z images were the first to be labelled and are somewhat sloppier; to exclude them use:
# images = sorted([raw_path for raw_path in raws if "_image" in raw_path.name and "D-R_Z" not in raw_path.name])
# labels = sorted([raw_path for raw_path in raws if "_label" in raw_path.name and "D-R_Z" not in raw_path.name])
if len(images) != len(labels):
    # zip() below would otherwise silently drop the unmatched tail.
    raise ValueError(f"found {len(images)} images but {len(labels)} labels in {raw_dir}")
processed_dir = Path("processed")
# cv.imwrite reports a missing target directory only through a False return value.
processed_dir.mkdir(parents=True, exist_ok=True)
l = 224  # tile edge length in pixels
random.seed(23)
empty = 0  # number of tiles with no labelled pixel
popu = 0  # number of tiles with at least one labelled pixel
cutoff = 1  # probability of discarding an empty tile (1 = write none of them)
# cv.COLOR_BGR2GRAY is a cvtColor code, not an imread flag; its value (6) equals
# IMREAD_ANYDEPTH | IMREAD_ANYCOLOR, which is what imread actually applied.
# Name those flags explicitly so the load behaviour is unchanged but readable.
read_flags = cv.IMREAD_ANYDEPTH | cv.IMREAD_ANYCOLOR


def _read(path):
    """Load `path` with read_flags; raise instead of returning None on failure."""
    array = cv.imread(path.as_posix(), read_flags)
    if array is None:
        raise OSError(f"cv.imread could not read {path}")
    return array


def _tile_path(src_path, i, j, tag=""):
    """Return processed_dir/<stem>_i<i>_j<j><tag><suffix> for tile (i, j) of src_path."""
    return processed_dir / f"{src_path.stem}_i{i}_j{j}{tag}{src_path.suffix}"


for image_path, label_path in zip(images, labels):
    image = _read(image_path)
    label = _read(label_path)
    if image.shape != label.shape:
        raise ValueError(
            f"shape mismatch: {image_path} {image.shape} vs {label_path} {label.shape}"
        )
    # Floor division: partial tiles at the bottom/right edges are dropped.
    i_max = image.shape[0] // l
    j_max = image.shape[1] // l
    # If the cells were labelled as 255, or something else mistakenly, instead of 1.
    label[label != 0] = 1
    for i in range(i_max):
        for j in range(j_max):
            rows = slice(l * i, l * (i + 1))
            cols = slice(l * j, l * (j + 1))
            cropped_image = image[rows, cols]
            cropped_label = label[rows, cols]
            if (cropped_label != 0).any():
                popu += 1
                tag = ""
            else:
                empty += 1
                if random.random() < cutoff:
                    continue
                # Kept empty tiles are tagged so they can be filtered out later.
                tag = "_empty"
            for src_path, tile in ((image_path, cropped_image), (label_path, cropped_label)):
                out_path = _tile_path(src_path, i, j, tag)
                if not cv.imwrite(out_path.as_posix(), tile):
                    raise OSError(f"cv.imwrite failed for {out_path}")
print(popu)
print(empty)
torch.cuda.set_device(0)
# Class names indexed by mask pixel value (masks are binarised to 0/1 when tiling).
codes = ["NOT-CELL", "CELL"]
bs = 4
# bs=16 and l=224 will use ~7300MiB for resnet34 before unfreezing
# bs=4 and l=224 use ~11500MiB for resnet50 before unfreezing
transforms = get_transforms(
    do_flip=True,
    flip_vert=True,
    max_zoom=1,  # consider
    max_rotate=45,
    max_lighting=None,  # None disables the lighting transforms
    max_warp=None,  # None disables the perspective warp
    p_affine=0.75,
    p_lighting=0.75,
)


def get_label_from_image(path):
    """Return the mask path paired with the image tile at `path`.

    Only the file name is rewritten ('_image_' -> '_label_'), so a parent
    directory whose name happens to contain '_image_' is left untouched.
    """
    path = Path(path)
    return (path.parent / path.name.replace('_image_', '_label_')).as_posix()


def _is_training_image(fname):
    """Keep image tiles (not masks) and skip tiles tagged as empty."""
    name = Path(fname).name
    return 'image' in name and "empty" not in name


src = (
    SegmentationItemList.from_folder(processed_dir)
    .filter_by_func(_is_training_image)
    .split_by_rand_pct(valid_pct=0.20, seed=1)
    .label_from_func(get_label_from_image, classes=codes)
)
data = (
    src.transform(transforms, tfm_y=True)
    .databunch(bs=bs)
    .normalize(imagenet_stats)
)
data.show_batch(2, figsize=(10, 7))
# models.resnet34
model_path = Path("../../models")
learn = unet_learner(data, models.resnet50, metrics=partial(dice, iou=True))
learn.loss_func = CrossEntropyFlat(axis=1, weight = torch.Tensor([1,1]).cuda())
lr_find(learn)
learn.recorder.plot()
lr = 2e-3
learn.fit_one_cycle(15, lr)
learn.save(model_path/"2019-07-02_RESNET50_IOU0.41_1stage")
!jupyter nbconvert gfp-cnp-train.ipynb --to html --output nbs/2019-06-26_RESNET50_IOU0.41_1stage
learn.load(model_path/"2019-07-02_RESNET50_IOU0.41_1stage");
# Second stage: unfreeze from the second-to-last layer group onward and fine-tune.
learn.freeze_to(-2)
lr_find(learn)
learn.recorder.plot()
lr = 1e-5
# Discriminative learning rates spread across the unfrozen layer groups,
# from lr/1000 (earliest) to lr/10 (last).
lrs = slice(lr / 1000, lr / 10)
learn.fit_one_cycle(20, lrs)
# NOTE(review): this name (2019-06-14, RESNET34) looks stale -- the learner is built
# on models.resnet50. Kept unchanged so anything loading this file still finds it.
stage2_name = "2019-06-14_RESNET34_IOU0.25_2stage"
# model_path (not the undefined `models_path`) is where the stage-1 weights went.
learn.save(model_path / stage2_name)
learn.export(file=model_path / f"{stage2_name}.pkl")
# Inspect the validation set and the model's predictions on it.
valid_ds = learn.data.valid_ds
print(len(valid_ds))  # number of validation items
print(valid_ds[0])  # tuple of input image and segment
print(valid_ds[0][1])
# preds = learn.get_preds(with_loss=True)
preds = learn.get_preds()
print(len(preds))  # tuple of probabilities and targets
print(preds[0].shape)  # predictions for all items
print(preds[0][0].shape)  # per-class probabilities for the first item
print(learn.data.classes)  # name of each class channel
print(preds[0][0][0].shape)  # probabilities for class 0
# Image(preds[1][0]).show()
N = len(valid_ds)
if preds[1].shape[0] != N:
    raise ValueError(
        f"get_preds returned {preds[1].shape[0]} targets for {N} validation items"
    )
# Index each item once; valid_ds[i] presumably re-opens the files on every access.
samples = [valid_ds[i] for i in range(N)]
xs = [x for x, _ in samples]
ys = [y for _, y in samples]
probs = preds[0]
p0s = [Image(probs[i][0]) for i in range(N)]  # NOT-CELL probability maps
p1s = [Image(probs[i][1]) for i in range(N)]  # CELL probability maps
argmax = [Image(probs[i].argmax(dim=0)) for i in range(N)]  # predicted class per pixel
print(xs[0].px.shape)
print(ys[0].px.shape)
print(p0s[0].px.shape)
print(p1s[0].px.shape)
ncol = 3
# Ceiling division: just enough rows for N panels (at least one row).
nrow = max(1, -(-N // ncol))

# Predicted mask (blue) over ground-truth mask (orange) for every validation item.
fig = plt.figure(figsize=(12, nrow * 5))
for i in range(N):
    # Subplot indices are 1-based; iterating range(N) covers every item, including the last.
    fig.add_subplot(nrow, ncol, i + 1)
    # plt.imshow(xs[i].px.permute(1, 2, 0), cmap="Oranges", alpha=0.5)
    plt.imshow(argmax[i].px, cmap="Blues", alpha=0.5)
    # plt.imshow(p1s[i].px, cmap="Blues", alpha=0.5)
    plt.imshow(ys[i].px[0], cmap="Oranges", alpha=0.5)
plt.show()

# Predicted mask (blue) over the input image for every validation item.
fig = plt.figure(figsize=(12, nrow * 5))
for i in range(N):
    fig.add_subplot(nrow, ncol, i + 1)
    # Channels-last for matplotlib; cmap has no effect if the input is 3-channel RGB.
    plt.imshow(xs[i].px.permute(1, 2, 0), cmap="Greys", alpha=1)
    plt.imshow(argmax[i].px, cmap="Blues", alpha=0.5)
    # plt.imshow(p1s[i].px, cmap="Blues", alpha=0.5)
    # plt.imshow(ys[i].px[0], cmap="Oranges", alpha=0.5)
plt.show()